# -*- coding: utf-8 -*-
"""
Created on Fri Jul 28 12:12:47 2023

@author: Andrei Sontag

Code to perform the curvature analysis of the voting data for the curved manifold
projection. This corresponds to the figure in Section SI.6.1 of the Supplementary
Information.
"""

# importing required libraries
import numpy as np
import pandas as pd
import os, fnmatch
import matplotlib.pyplot as plt

# Close every open matplotlib figure before the analysis starts (e.g. figures
# left over from an earlier run in the same interactive session).
plt.close('all')

# Working directory that holds the exported .csv files. './' leaves the current
# directory unchanged; edit this path to analyse data stored elsewhere.
os.chdir(r'./')


def find_csv_files(directory='.', pattern=r'*.csv'):
    """Return the sorted names of the entries of `directory` matching `pattern`.

    Parameters
    ----------
    directory : str
        Folder to scan (not recursive).
    pattern : str
        Shell-style wildcard understood by ``fnmatch.fnmatch``.

    Returns
    -------
    list of str
        Bare entry names (no directory prefix) in sorted order. ``os.listdir``
        returns entries in arbitrary order, so sorting makes the processing
        order -- and therefore the floating-point accumulation order of the
        analysis -- reproducible across machines.
    """
    return sorted(name for name in os.listdir(directory)
                  if fnmatch.fnmatch(name, pattern))


# Find data files
pattern = r'*.csv'
files = find_csv_files('.', pattern)

# Show the name of the files
print(files)

# Common grid of alignment values used to bin the data. Each group size allows a
# different set of ratios, so every round is assigned to the nearest grid point,
# which unifies the data. The grid points are spaced 1/30 apart and run from -1
# to 1 (61 bin centres).
BIN_WIDTH = 1/30
ratios = np.arange(-1, 1 + BIN_WIDTH, BIN_WIDTH)

# Per-bin accumulators (three independent arrays, one entry per grid point):
#   count_r     - number of rounds that fell in the bin
#   mean_Abr    - running sum of the abstention fraction (turned into a mean later)
#   sqrmean_Abr - running sum of the squared abstention fraction
count_r, mean_Abr, sqrmean_Abr = (np.zeros_like(ratios) for _ in range(3))

# Fractions of A and B votes for every round of every experiment, concatenated
dataA = np.empty(0)
dataB = np.empty(0)

# game has 120 rounds + 1 to make the code more streamlined
# (range(1, nrounds) then runs over rounds 1..120)
nrounds = 121

# Headings of the columns holding the number of A votes, B votes and abstentions
# in each round. They are identical for every file, so they are built once, as
# plain lists of strings (no reliance on numpy promoting an empty float array to
# a string array).
rounds = range(1, nrounds)
total_Avotes = [r"my_voting.{0:.0f}.group.A_votes".format(k) for k in rounds]
total_Bvotes = [r"my_voting.{0:.0f}.group.B_votes".format(k) for k in rounds]
total_Absvotes = [r"my_voting.{0:.0f}.group.abstain_votes".format(k) for k in rounds]

for file in files:
    # read the file into a data frame
    df = pd.read_csv(file)

    # Number of A votes, B votes and abstentions at each round. Only the first
    # row of the file is used.
    # NOTE(review): presumably every row of an export repeats the same group
    # totals -- confirm against the data format.
    tAvotes = df[total_Avotes].to_numpy()[0,:]
    tBvotes = df[total_Bvotes].to_numpy()[0,:]
    tAbsvotes = df[total_Absvotes].to_numpy()[0,:]

    # Fail loudly on missing counts: a NaN would otherwise be added to a bin and
    # silently turn its mean and error into NaN.
    if not np.isfinite(np.array([tAvotes, tBvotes, tAbsvotes], dtype=float)).all():
        raise ValueError("{}: missing or non-finite vote counts".format(file))

    # Group size, taken from the first round (A + B + abstentions)
    N = int(tAvotes[0] + tBvotes[0] + tAbsvotes[0])
    if N <= 0:
        raise ValueError("{}: group size must be positive, got {}".format(file, N))

    # brings the data from each experiment in a single array
    dataA = np.append(dataA,tAvotes/N)
    dataB = np.append(dataB,tBvotes/N)

    # for each round
    for votesA, votesB, votesAbs in zip(tAvotes, tBvotes, tAbsvotes):
        # projection of the point: difference between A and B votes as a
        # fraction of the whole group
        round_ratio = (votesA - votesB)/N
        # index of the nearest bin centre
        idx = np.argmin(np.abs(ratios - round_ratio))

        # proportion of abstentions in this round
        abstained = votesAbs/N
        # accumulate the sum, the sum of squares and the number of data points
        mean_Abr[idx] += abstained
        sqrmean_Abr[idx] += abstained**2
        count_r[idx] += 1
        
#%%
### DATA ANALYSIS

# Symmetrise data: every bin is pooled with its mirror image (alignment z and -z),
# which is the same array read back to front.
# NOTE(review): when the number of bins is odd the central bin is its own mirror,
# so its sum and count are both doubled -- the mean is unaffected, but the count
# entering the error of the mean is twice the number of observations.
n_sym = count_r + np.flip(count_r)

# Pooled sum of abstention fractions. This has to be taken before mean_Abr is
# converted from a running sum into a mean on the next line.
mean_Abr2 = mean_Abr + np.flip(mean_Abr)

# Average number of abstentions per bin (plain and symmetrised)
mean_Abr = mean_Abr/count_r
mean_Abr2 = mean_Abr2/n_sym

# Sample standard deviation per bin, from the sum of squares and the mean
sqr_Abr = np.sqrt((sqrmean_Abr - count_r*mean_Abr**2)/(count_r - 1))
# Same for the symmetrised data
pooled_squares = sqrmean_Abr + np.flip(sqrmean_Abr)
sqr_Abr2 = np.sqrt((pooled_squares - n_sym*mean_Abr2**2)/(n_sym - 1))

# Error of the mean (used to define 95% confidence intervals)
merror = sqr_Abr2/np.sqrt(n_sym)

# Keep only bins with at least 2 data points; fewer do not provide enough data
# to estimate averages and errors
enough_data = n_sym > 1
mAbr = mean_Abr2[enough_data]
sqrAbr = sqr_Abr2[enough_data]
mEr = merror[enough_data]
rats = ratios[enough_data]

#%% Data fitting
from scipy.optimize import curve_fit

# Straight line in Euclidean coordinates
def func_str(x,a):
    """Straight line ``1 - x - y = a``, solved for y.

    Parameters
    ----------
    x : float or array-like
        Abscissa (fraction of X votes in the figure).
    a : float
        Constant offset of the line.

    Returns
    -------
    ``1 - a - x``, with the same shape as `x`.
    """
    intercept = 1 - a
    return intercept - x

# hyperbola in Euclidean coordinates
def func_curv(x,a):
    """Hyperbola ``1 - x - y - a*x*y = 0``, solved for y.

    Parameters
    ----------
    x : float or array-like
        Abscissa (fraction of X votes in the figure).
    a : float
        Coefficient of the ``x*y`` term.

    Returns
    -------
    ``(1 - x)/(1 + a*x)``, with the same shape as `x`.
    """
    numerator = 1 - x
    denominator = 1 + a*x
    return numerator/denominator

# 1-x-y - a*xy = b (mix between line and curved)
def func_curvb(x,a,b):
    """Curve ``1 - x - y - a*x*y = b``, solved for y.

    Combines the offset of the straight line (`b`) with the ``x*y`` term of the
    hyperbola (`a`).

    Parameters
    ----------
    x : float or array-like
        Abscissa (fraction of X votes in the figure).
    a : float
        Coefficient of the ``x*y`` term.
    b : float
        Constant offset.

    Returns
    -------
    ``(1 - x - b)/(1 + a*x)``, with the same shape as `x`.
    """
    numerator = 1 - x - b
    denominator = 1 + a*x
    return numerator/denominator

# straight-line in projected coordinates
def funline(z,a):
    """Straight-line model in projected coordinates: a constant level `a`.

    Parameters
    ----------
    z : sequence
        Projected alignment values; only its length is used.
    a : float
        Constant value of the model.

    Returns
    -------
    list
        `a` repeated ``len(z)`` times.
    """
    return [a]*len(z)

# hyperbola in projected coordinates
def funcurv(z,a):
    """Hyperbola ``1 - x - y - a*x*y = 0`` in projected coordinates.

    With ``z = x - y`` the hyperbola gives
    ``x + y = (-2 + sqrt(4 + 4*a + (a*z)**2))/a``; the function returns
    ``1 - (x + y)`` as a function of `z`.

    Parameters
    ----------
    z : float or numpy array
        Projected alignment.
    a : float
        Coefficient of the ``x*y`` term (must be non-zero).
    """
    discriminant = 4 + 4*a + (a*z)**2
    votes_cast = (-2 + np.sqrt(discriminant))/a
    return 1 - votes_cast

# 1-x-y - a*xy = b in projected coordinates
def funcurvb(z,a,b):
    """Curve ``1 - x - y - a*x*y = b`` in projected coordinates.

    With ``z = x - y`` the curve gives
    ``x + y = (-2 + sqrt(4 + 4*a*(1 - b) + (a*z)**2))/a``; the function returns
    ``1 - (x + y)`` as a function of `z`.

    Parameters
    ----------
    z : float or numpy array
        Projected alignment.
    a : float
        Coefficient of the ``x*y`` term (must be non-zero).
    b : float
        Constant offset.
    """
    discriminant = 4 + 4*a*(1 - b) + (a*z)**2
    votes_cast = (-2 + np.sqrt(discriminant))/a
    return 1 - votes_cast

# Data entering the fits:
#   rats - alignment value of each retained bin
#   mAbr - mean abstention fraction in each bin
#   mEr  - error of that mean
# The two outermost bins on each side are dropped (points at the edges with very
# low data counts).
edge_points = [0, -1, 1, -2]
xdata, ydata, dy = (np.delete(series, edge_points) for series in (rats, mAbr, mEr))

# Weighted least-squares fits in the projected variables, so the distance is
# minimised within each bin rather than in Euclidean coordinates.
# full_output=True makes curve_fit also return an info dictionary, which
# contains the residual vector 'fvec'.
line_fit, curved_fit, curved_fitb = (
    curve_fit(model, xdata, ydata, sigma=dy, full_output=True)
    for model in (funline, funcurv, funcurvb)
)

# Each result starts with (parameter values, covariance matrix, info dictionary).
# The numbers in the comments below are the values recorded by the author.
popt, pcovt, info_line = line_fit[:3]
popc, pcovc, info_curv = curved_fit[:3]
popcb, pcovb, info_curvb = curved_fitb[:3]

# parameter values
#   popt:  0.13696779
#   popc:  1.06212843
#   popcb: a = 0.91566839, b = 0.02042731

# errors
#   pcovt: 0.0055751 (square root)
#   pcovc: 0.02773344 (square root)
#   pcovb: 0.08976218, 0.01202937

# normalised residuals
res_line = info_line['fvec']    # sum: 800.43
res_curv = info_curv['fvec']    # sum: 167.69
res_curvb = info_curvb['fvec']  # sum: 159.58

# ns = 57
# BIC line: 385.10
# BIC hyperbola: 296.00
# BIC Combined: 297.22

#%%
# Normal quantile for a two-sided 95% confidence interval
zscore = 1.96

# pyplot is already imported as plt at the top of the file
plt.rcParams.update({'font.size': 20})
plt.rcParams['axes.linewidth'] = 3

fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# --- Left panel: binned data and fitted curves in the (X votes, Y votes) plane.
# The bins are indexed by z = (A - B)/N = x - y and hold the mean abstention
# fraction m = 1 - x - y, so the point of a bin is
#     x = ((1 - m) + z)/2,   y = ((1 - m) - z)/2.
# This is the same relation funcurv/funcurvb are built on. Converting with
# 0.5*(1 +/- z)*(1 - m) instead would treat z as (x - y)/(x + y) and slide the
# points along the lines of constant x + y, away from where they were binned.
cast = 1 - mAbr                        # mean fraction of votes cast, x + y
cast_lo = 1 - (mAbr + zscore*mEr)      # bounds of the 95% confidence band
cast_hi = 1 - (mAbr - zscore*mEr)

axes[0].scatter(0.5*(cast + rats),0.5*(cast - rats),color='k')
axes[0].plot(0.5*(cast_lo + rats),0.5*(cast_lo - rats),color='k')
axes[0].plot(0.5*(cast_hi + rats),0.5*(cast_hi - rats),color='k')

# Abscissa for the reference line and the fitted curves, covering [0, 1]
xline = (ratios+1)/2
# x + y = 1: no abstentions
axes[0].plot(xline,1-xline,'k-',linewidth=3)
# fitted straight line and fitted hyperbola
axes[0].plot(xline,1-xline-popt[0], color='red',label=r'1-x-y = $\alpha$',linewidth=3)
axes[0].plot(xline, func_curv(xline,popc[0]),color='royalblue',label=r'1-x-y-$\gamma$xy = 0',linewidth=3)
#axes[0].plot((ratios+1)/2, func_curvb((ratios+1)/2,popcb[0],popcb[1]),color='orange',label=r'1-x-y-$\gamma$xy = $\beta$')
axes[0].axis([0,1,0,1])
axes[0].plot(xline,1-xline,'k--')
axes[0].set(xlabel='X votes',ylabel='Y votes')
#axes[0].set_title('Projected data (all experiments)')
axes[0].legend()
axes[0].xaxis.set_tick_params(width=3)
axes[0].yaxis.set_tick_params(width=3)

# --- Right panel: normalised residuals of the two fits against the alignment.
# Both residual vectors come from curve_fit's 'fvec', which is
# (model - data)/sigma for every model, so they are plotted with the same sign.
# Negating only the straight-line residuals would mix two sign conventions on
# one axis.
axes[1].scatter(xdata,res_line,color='red',label=r'1-x-y = $\alpha$')
axes[1].scatter(xdata,res_curv,marker='s',color='royalblue',label=r'1-x-y-$\gamma$xy = 0')
#axes[1].scatter(xdata,res_curvb,marker='^',color='darkorange',label=r'1-x-y-$\gamma$xy = $\beta$')
# zero-residual reference line
axes[1].plot(rats,rats*0,'k--')
axes[1].set(xlabel='Projected alignment',ylabel='Normalised residuals')
#axes[1].set_title('Normalised residuals')
axes[1].legend()

#plt.savefig('curvature_analysis.svg',format='svg',dpi=300)
plt.show()